/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.beam.sdk.testing; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.emptyIterable; import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.common.io.ByteStreams; import com.google.common.io.CountingInputStream; import com.google.common.io.CountingOutputStream; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.util.Arrays; import java.util.Collection; import java.util.List; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.Coder.NonDeterministicException; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.UnownedInputStream; import org.apache.beam.sdk.util.UnownedOutputStream; import org.apache.beam.sdk.util.common.ElementByteSizeObserver; /** * Properties for use in {@link Coder} tests. These are implemented with junit assertions * rather than as predicates for the sake of error messages. * * <p>We serialize and deserialize the coder to make sure that any state information required by * the coder is preserved. This causes tests written such that coders that lose information during * serialization or change state during encoding/decoding will fail. */ public class CoderProperties { /** * All the contexts, for use in test cases. */ public static final List<Coder.Context> ALL_CONTEXTS = ImmutableList.of( Coder.Context.OUTER, Coder.Context.NESTED); /** * Verifies that for the given {@code Coder<T>}, and values of * type {@code T}, if the values are equal then the encoded bytes are equal, in any * {@code Coder.Context}. */ public static <T> void coderDeterministic( Coder<T> coder, T value1, T value2) throws Exception { for (Coder.Context context : ALL_CONTEXTS) { coderDeterministicInContext(coder, context, value1, value2); } } /** * Verifies that for the given {@code Coder<T>}, {@code Coder.Context}, and values of * type {@code T}, if the values are equal then the encoded bytes are equal. */ public static <T> void coderDeterministicInContext( Coder<T> coder, Coder.Context context, T value1, T value2) throws Exception { try { coder.verifyDeterministic(); } catch (NonDeterministicException e) { fail("Expected that the coder is deterministic"); } assertThat("Expected that the passed in values are equal()", value1, equalTo(value2)); assertThat( encode(coder, context, value1), equalTo(encode(coder, context, value2))); } /** * Verifies that for the given {@code Coder<T>}, * and value of type {@code T}, encoding followed by decoding yields an * equal value of type {@code T}, in any {@code Coder.Context}. */ public static <T> void coderDecodeEncodeEqual( Coder<T> coder, T value) throws Exception { for (Coder.Context context : ALL_CONTEXTS) { coderDecodeEncodeEqualInContext(coder, context, value); } } /** * Verifies that for the given {@code Coder<T>}, {@code Coder.Context}, * and value of type {@code T}, encoding followed by decoding yields an * equal value of type {@code T}. */ public static <T> void coderDecodeEncodeEqualInContext( Coder<T> coder, Coder.Context context, T value) throws Exception { assertThat(decodeEncode(coder, context, value), equalTo(value)); } /** * Verifies that for the given {@code Coder<Collection<T>>}, * and value of type {@code Collection<T>}, encoding followed by decoding yields an * equal value of type {@code Collection<T>}, in any {@code Coder.Context}. */ public static <T, CollectionT extends Collection<T>> void coderDecodeEncodeContentsEqual( Coder<CollectionT> coder, CollectionT value) throws Exception { for (Coder.Context context : ALL_CONTEXTS) { coderDecodeEncodeContentsEqualInContext(coder, context, value); } } /** * Verifies that for the given {@code Coder<Collection<T>>}, * and value of type {@code Collection<T>}, encoding followed by decoding yields an * equal value of type {@code Collection<T>}, in the given {@code Coder.Context}. */ @SuppressWarnings("unchecked") public static <T, CollectionT extends Collection<T>> void coderDecodeEncodeContentsEqualInContext( Coder<CollectionT> coder, Coder.Context context, CollectionT value) throws Exception { // Matchers.containsInAnyOrder() requires at least one element Collection<T> result = decodeEncode(coder, context, value); if (value.isEmpty()) { assertThat(result, emptyIterable()); } else { // This is the only Matchers.containInAnyOrder() overload that takes literal values assertThat(result, containsInAnyOrder((T[]) value.toArray())); } } /** * Verifies that for the given {@code Coder<Collection<T>>}, * and value of type {@code Collection<T>}, encoding followed by decoding yields an * equal value of type {@code Collection<T>}, in any {@code Coder.Context}. */ public static <T, IterableT extends Iterable<T>> void coderDecodeEncodeContentsInSameOrder( Coder<IterableT> coder, IterableT value) throws Exception { for (Coder.Context context : ALL_CONTEXTS) { CoderProperties.<T, IterableT>coderDecodeEncodeContentsInSameOrderInContext( coder, context, value); } } /** * Verifies that for the given {@code Coder<Iterable<T>>}, * and value of type {@code Iterable<T>}, encoding followed by decoding yields an * equal value of type {@code Collection<T>}, in the given {@code Coder.Context}. */ @SuppressWarnings("unchecked") public static <T, IterableT extends Iterable<T>> void coderDecodeEncodeContentsInSameOrderInContext( Coder<IterableT> coder, Coder.Context context, IterableT value) throws Exception { Iterable<T> result = decodeEncode(coder, context, value); // Matchers.contains() requires at least one element if (Iterables.isEmpty(value)) { assertThat(result, emptyIterable()); } else { // This is the only Matchers.contains() overload that takes literal values assertThat(result, contains((T[]) Iterables.toArray(value, Object.class))); } } /** * Verifies that the given {@code Coder<T>} can be correctly serialized and * deserialized. */ public static <T> void coderSerializable(Coder<T> coder) { SerializableUtils.ensureSerializable(coder); } /** * Verifies that for the given {@code Coder<T>} and values of * type {@code T}, the values are equal if and only if the * encoded bytes are equal. */ public static <T> void coderConsistentWithEquals( Coder<T> coder, T value1, T value2) throws Exception { for (Coder.Context context : ALL_CONTEXTS) { CoderProperties.<T>coderConsistentWithEqualsInContext(coder, context, value1, value2); } } /** * Verifies that for the given {@code Coder<T>}, {@code Coder.Context}, and * values of type {@code T}, the values are equal if and only if the * encoded bytes are equal, in any {@code Coder.Context}. */ public static <T> void coderConsistentWithEqualsInContext( Coder<T> coder, Coder.Context context, T value1, T value2) throws Exception { assertEquals( value1.equals(value2), Arrays.equals( encode(coder, context, value1), encode(coder, context, value2))); } /** * Verifies that for the given {@code Coder<T>} and values of * type {@code T}, the structural values are equal if and only if the * encoded bytes are equal. */ public static <T> void structuralValueConsistentWithEquals( Coder<T> coder, T value1, T value2) throws Exception { for (Coder.Context context : ALL_CONTEXTS) { CoderProperties.<T>structuralValueConsistentWithEqualsInContext( coder, context, value1, value2); } } /** * Verifies that for the given {@code Coder<T>}, {@code Coder.Context}, and * values of type {@code T}, the structural values are equal if and only if the * encoded bytes are equal, in any {@code Coder.Context}. */ public static <T> void structuralValueConsistentWithEqualsInContext( Coder<T> coder, Coder.Context context, T value1, T value2) throws Exception { assertEquals( coder.structuralValue(value1).equals(coder.structuralValue(value2)), Arrays.equals( encode(coder, context, value1), encode(coder, context, value2))); } /** * Verifies that for the given {@code Coder<T>} and value of type {@code T}, * the structural value is equal to the structural value yield by encoding * and decoding the original value. * * <p>This is useful to test the correct implementation of a Coder structural * equality with values that don't implement the equals contract. */ public static <T> void structuralValueDecodeEncodeEqual( Coder<T> coder, T value) throws Exception { for (Coder.Context context : ALL_CONTEXTS) { CoderProperties.<T>structuralValueDecodeEncodeEqualInContext( coder, context, value); } } /** * Verifies that for the given {@code Coder<T>}, {@code Coder.Context}, * and value of type {@code T}, the structural value is equal to the * structural value yield by encoding and decoding the original value, * in any {@code Coder.Context}. */ public static <T> void structuralValueDecodeEncodeEqualInContext( Coder<T> coder, Coder.Context context, T value) throws Exception { assertEquals( coder.structuralValue(value), coder.structuralValue(decodeEncode(coder, context, value))); } private static final String DECODING_WIRE_FORMAT_MESSAGE = "Decoded value from known wire format does not match expected value." + " This probably means that this Coder no longer correctly decodes" + " a prior wire format. Changing the wire formats this Coder can read" + " should be avoided, as it is likely to cause breakage."; public static <T> void coderDecodesBase64(Coder<T> coder, String base64Encoding, T value) throws Exception { assertThat(DECODING_WIRE_FORMAT_MESSAGE, CoderUtils.decodeFromBase64(coder, base64Encoding), equalTo(value)); } public static <T> void coderDecodesBase64( Coder<T> coder, List<String> base64Encodings, List<T> values) throws Exception { assertThat("List of base64 encodings has different size than List of values", base64Encodings.size(), equalTo(values.size())); for (int i = 0; i < base64Encodings.size(); i++) { coderDecodesBase64(coder, base64Encodings.get(i), values.get(i)); } } private static final String ENCODING_WIRE_FORMAT_MESSAGE = "Encoded value does not match expected wire format." + " Changing the wire format should be avoided, as it is likely to cause breakage." + " If you truly intend to change the wire format for this Coder," + " See org.apache.beam.sdk.coders.PrintBase64Encoding for how to generate" + " new test data."; public static <T> void coderEncodesBase64(Coder<T> coder, T value, String base64Encoding) throws Exception { assertThat(ENCODING_WIRE_FORMAT_MESSAGE, CoderUtils.encodeToBase64(coder, value), equalTo(base64Encoding)); } public static <T> void coderEncodesBase64( Coder<T> coder, List<T> values, List<String> base64Encodings) throws Exception { assertThat("List of base64 encodings has different size than List of values", base64Encodings.size(), equalTo(values.size())); for (int i = 0; i < base64Encodings.size(); i++) { coderEncodesBase64(coder, values.get(i), base64Encodings.get(i)); } } @SuppressWarnings("unchecked") public static <T, IterableT extends Iterable<T>> void coderDecodesBase64ContentsEqual( Coder<IterableT> coder, String base64Encoding, IterableT expected) throws Exception { IterableT result = CoderUtils.decodeFromBase64(coder, base64Encoding); if (Iterables.isEmpty(expected)) { assertThat(ENCODING_WIRE_FORMAT_MESSAGE, result, emptyIterable()); } else { assertThat(ENCODING_WIRE_FORMAT_MESSAGE, result, containsInAnyOrder((T[]) Iterables.toArray(expected, Object.class))); } } public static <T, IterableT extends Iterable<T>> void coderDecodesBase64ContentsEqual( Coder<IterableT> coder, List<String> base64Encodings, List<IterableT> expected) throws Exception { assertThat("List of base64 encodings has different size than List of values", base64Encodings.size(), equalTo(expected.size())); for (int i = 0; i < base64Encodings.size(); i++) { coderDecodesBase64ContentsEqual(coder, base64Encodings.get(i), expected.get(i)); } } ////////////////////////////////////////////////////////////////////////// @VisibleForTesting static <T> byte[] encode( Coder<T> coder, Coder.Context context, T value) throws CoderException, IOException { @SuppressWarnings("unchecked") Coder<T> deserializedCoder = SerializableUtils.clone(coder); ByteArrayOutputStream os = new ByteArrayOutputStream(); deserializedCoder.encode(value, new UnownedOutputStream(os), context); return os.toByteArray(); } @VisibleForTesting static <T> T decode( Coder<T> coder, Coder.Context context, byte[] bytes) throws CoderException, IOException { @SuppressWarnings("unchecked") Coder<T> deserializedCoder = SerializableUtils.clone(coder); byte[] buffer; if (context == Coder.Context.NESTED) { buffer = new byte[bytes.length + 1]; System.arraycopy(bytes, 0, buffer, 0, bytes.length); buffer[bytes.length] = 1; } else { buffer = bytes; } CountingInputStream cis = new CountingInputStream(new ByteArrayInputStream(buffer)); T value = deserializedCoder.decode(new UnownedInputStream(cis), context); assertThat("consumed bytes equal to encoded bytes", cis.getCount(), equalTo((long) bytes.length)); return value; } private static <T> T decodeEncode(Coder<T> coder, Coder.Context context, T value) throws CoderException, IOException { return decode(coder, context, encode(coder, context, value)); } /** * A utility method that passes the given (unencoded) elements through * coder's registerByteSizeObserver() and encode() methods, and confirms * they are mutually consistent. This is useful for testing coder * implementations. */ public static <T> void testByteCount(Coder<T> coder, Coder.Context context, T[] elements) throws Exception { TestElementByteSizeObserver observer = new TestElementByteSizeObserver(); try (CountingOutputStream os = new CountingOutputStream(ByteStreams.nullOutputStream())) { for (T elem : elements) { coder.registerByteSizeObserver(elem, observer); coder.encode(elem, os, context); observer.advance(); } long expectedLength = os.getCount(); if (!context.isWholeStream) { assertEquals(expectedLength, observer.getSum()); } assertEquals(elements.length, observer.getCount()); } } /** * An {@link ElementByteSizeObserver} that records the observed element sizes for testing * purposes. */ public static class TestElementByteSizeObserver extends ElementByteSizeObserver { private long currentSum = 0; private long count = 0; @Override protected void reportElementSize(long elementByteSize) { count++; currentSum += elementByteSize; } public double getMean() { return ((double) currentSum) / count; } public long getSum() { return currentSum; } public long getCount() { return count; } public void reset() { currentSum = 0; count = 0; } public long getSumAndReset() { long returnValue = currentSum; reset(); return returnValue; } } }